""" Calculate Table 4
"""

import os
from io import open
import tempfile
import shutil
import numpy as np
import pandas as pd

# SETTINGS  ===========================================================================================================

# Per-subject summary CSV that main() loads with pd.read_csv; main() reads its
# `eid` and `Config` columns.
FNAME_IN = 'data/sub_info.csv'
# CSV that main() writes (without the index) with the cleaned per-subject statistics.
FNAME_OUT = 'results/03_hits_data.csv'


# SUPPORTING FUNCTIONS  ===============================================================================================

def clean_output(out_text, file_name, subfolder=""):
    """Compile LaTeX source with pdflatex in a scratch directory and copy the results out.

    The source is written to ``<file_name>.tex`` inside a fresh temporary
    directory, ``pdflatex`` is run there, and both the resulting ``.pdf`` and
    the ``.tex`` are copied to ``os.getcwd() + subfolder``.

    Args:
        out_text: Complete LaTeX document to compile.
        file_name: Base name (no extension) of the .tex/.pdf files.
        subfolder: String appended verbatim to the current working directory
            to form the destination directory (e.g. ``"//results//"``).

    Raises:
        OSError: If pdflatex cannot be started, or if the PDF was not produced
            and therefore cannot be copied.
    """
    # Function-scope import so the block does not depend on a file-level import.
    import subprocess

    print(file_name, subfolder)
    current = os.getcwd()
    destination = current + subfolder
    tex_file_name = file_name + ".tex"
    pdf_file_name = file_name + ".pdf"

    # Scratch directory that keeps pdflatex's auxiliary files out of the project tree.
    temp = tempfile.mkdtemp()
    try:
        # Write the LaTeX source; the context manager closes the file even on error.
        tex_path = os.path.join(temp, tex_file_name)
        with open(tex_path, 'w') as f:
            f.write(out_text)

        # Run pdflatex inside the scratch directory via `cwd` instead of
        # os.chdir(), so the process-wide working directory is never changed.
        # The argument list avoids passing file_name through a shell.
        # The exit status is ignored, as it was with os.system().
        subprocess.call(["pdflatex", tex_file_name], cwd=temp)

        # Copy the pdf and the tex source to the destination directory.
        print(current)
        print(destination)
        shutil.copy(os.path.join(temp, pdf_file_name), destination)
        shutil.copy(tex_path, destination)
    finally:
        # Always remove the scratch directory, even when compilation or copying fails.
        shutil.rmtree(temp)


def beginDocument():
    """Return the LaTeX preamble that opens the document and its tikzpicture.

    The returned text declares an 11pt article, loads the packages and TikZ
    libraries used for the table, activates the ``preview`` package around
    ``tikzpicture`` with a 2mm border, and ends just after the opening of the
    ``tikzpicture`` environment. Pair it with endDocument().
    """
    # Every backslash is written as an escaped pair: a bare backslash followed
    # by "d" is an invalid escape sequence that newer Python versions warn about.
    string = """\\documentclass[11pt]{article} \n
\\usepackage{geometry,amsmath,amssymb,datetime,color,everypage,lastpage,multirow,textpos,nicefrac,setspace,tikz} \n
\\usetikzlibrary{calc,arrows,automata,shapes.misc,shapes.arrows,chains,matrix,positioning,scopes,decorations.pathmorphing,shadows} \n
\\usepackage[active,tightpage]{preview} \n
\\PreviewEnvironment{tikzpicture} \n
\\setlength\\PreviewBorder{2mm} \n
\\begin{document} \n
\\begin{tikzpicture} \n"""
    return string


def endDocument():
    """Return the LaTeX text that closes the tikzpicture and the document."""
    # An empty first entry reproduces the leading newline; the empty entry in
    # the middle reproduces the blank line between the two closing commands.
    closing_lines = ("", "\\end{tikzpicture}", "", "\\end{document}")
    return "\n".join(closing_lines)


def texLine(x1, y1, x2, y2, line_type):
    """Return a TikZ draw command for the segment (x1, y1) -- (x2, y2).

    Args:
        x1, y1: Start coordinates; converted with str().
        x2, y2: End coordinates; converted with str().
        line_type: 1 for a plain line (no trailing newline in the output),
            2 for a gray dashed line (output ends with a space and newline).

    Returns:
        The TikZ command as a string.

    Raises:
        ValueError: If line_type is neither 1 nor 2.
    """
    segment = "(" + str(x1) + "," + str(y1) + ") -- (" + str(x2) + "," + str(y2) + ")"
    if line_type == 1:
        return "\\draw " + segment + ";"
    if line_type == 2:
        return "\\draw [color=gray,dashed] " + segment + "; \n"
    # Previously an unknown type fell through to an unbound local variable.
    raise ValueError("texLine(): unsupported line_type %r (expected 1 or 2)" % (line_type,))


def texNode(text, x1, y1, rotate=",rotate=0", size="\\normalsize ", align=",align=center", width="text width=1cm "):
    """Return a TikZ node command that places `text` at (x1, y1).

    The node options are the concatenation of `width`, `rotate` and `align`
    (in that order), and `size` is emitted immediately before `text` inside
    the node braces. The returned string ends with a space and a newline.
    """
    options = width + rotate + align
    position = "(" + str(x1) + "," + str(y1) + ")"
    pieces = ["\\node [", options, "] at ", position, " {", size, text, "}; \n"]
    return "".join(pieces)


def texLine1(x1, y1, x2, y2, line_type=1, color="black"):
    """Return a TikZ draw command for the segment (x1, y1) -- (x2, y2).

    Args:
        x1, y1: Start coordinates; converted with str().
        x2, y2: End coordinates; converted with str().
        line_type: 1 for a solid line in `color` (no trailing newline in the
            output), 2 for a gray dashed line (output ends with a space and
            newline; `color` is ignored).
        color: TikZ color name used when line_type is 1.

    Returns:
        The TikZ command as a string.

    Raises:
        ValueError: If line_type is neither 1 nor 2.
    """
    segment = "(" + str(x1) + "," + str(y1) + ") -- (" + str(x2) + "," + str(y2) + ")"
    if line_type == 1:
        return "\\draw [color=" + color + "] " + segment + ";"
    if line_type == 2:
        return "\\draw [color=gray,dashed] " + segment + "; \n"
    # Previously an unknown type fell through to an unbound local variable.
    raise ValueError("texLine1(): unsupported line_type %r (expected 1 or 2)" % (line_type,))


def texRectangle1(x1, y1, x2, y2, line_type=1, color="black", pattern=1):
    """Return a TikZ command drawing the rectangle with corners (x1, y1) and (x2, y2).

    Args:
        x1, y1: First corner; formatted with two decimals.
        x2, y2: Opposite corner; formatted with two decimals.
        line_type: Accepted for interface compatibility; not used.
        color: TikZ color name for the fill (pattern 1) or for the hatch
            lines and border (pattern 2).
        pattern: 1 for a half-opaque filled rectangle with a gray border,
            2 for a rectangle hatched with north-west lines.

    Returns:
        The TikZ command as a string ending with a newline.

    Raises:
        ValueError: If pattern is neither 1 nor 2.
    """
    corners = "(%.02f,%.02f) rectangle (%.02f,%.02f);\n" % (x1, y1, x2, y2)
    if pattern == 1:
        return "\\filldraw[draw=gray,fill=" + color + ", opacity=0.5] " + corners
    if pattern == 2:
        return "\\draw[pattern=north west lines,pattern color=" + color + ",draw=" + color + "] " + corners
    # Previously an unknown pattern fell through to an unbound local variable.
    raise ValueError("texRectangle1(): unsupported pattern %r (expected 1 or 2)" % (pattern,))


def texNode1(text, x1, y1, rotate=",rotate=0", size="\\normalsize ", align=",align=center", width="text width=3cm "):
    """Return a TikZ node command that places `text` at (x1, y1).

    Builds the same command shape as texNode(); the default text width here
    is 3cm. Options are emitted in the order width, rotate, align, and `size`
    directly precedes `text` inside the node braces.
    """
    options = "".join((width, rotate, align))
    content = size + text
    return "\\node [%s] at (%s,%s) {%s}; \n" % (options, str(x1), str(y1), content)


def my_bootstrap(sample1, stat, n_boot=10000):
    """Estimate the standard error of a statistic by bootstrap resampling.

    Draws `n_boot` resamples (with replacement, each the size of the original
    sample) using numpy's global random state, applies `stat` along axis 1 of
    the resample matrix, and returns the standard deviation of the results.

    Args:
        sample1: One-dimensional sequence of observations.
        stat: Callable invoked as ``stat(samples, 1)`` where `samples` has
            shape (n_boot, len(sample1)); e.g. ``np.mean``.
        n_boot: Number of bootstrap resamples (default 10000).

    Returns:
        The standard deviation of the statistic across resamples.

    Raises:
        ValueError: If the sample is empty.
    """
    my_sample = np.asarray(sample1)
    ns1 = len(my_sample)
    if ns1 == 0:
        # np.random.randint(0, 0, ...) would raise a far less helpful ValueError.
        raise ValueError("my_bootstrap(): cannot bootstrap an empty sample")

    # Each row of `ix` holds the indices of one resample.
    ix = np.random.randint(0, ns1, (n_boot, ns1))
    samples = my_sample[ix]
    sterr = np.std(stat(samples, 1))

    return sterr


def my_stat_output(vec, x_loc, y_loc, stat="mean", message=",rotate=0"):
    """Return TikZ nodes showing a statistic of `vec` and its bootstrap standard error.

    The statistic (rounded to 3 decimals) is placed 0.175 above `y_loc`, and
    the bootstrap standard error from my_bootstrap() (rounded to 3 decimals,
    in parentheses, footnote size) is placed 0.175 below it.

    Args:
        vec: Observations to summarise.
        x_loc: x coordinate shared by both nodes.
        y_loc: y coordinate the two nodes are centred around.
        stat: Name of the statistic; only "mean" is supported.
        message: Passed to texNode() as its `rotate` option string.

    Returns:
        The concatenated TikZ node commands.

    Raises:
        ValueError: If `stat` is not "mean".
    """
    # Validate up front: previously an unknown name printed a notice and then
    # crashed on unbound locals.
    if stat != "mean":
        raise ValueError("NEED TO SPECIFY TEST. See my_stat_output().. (unsupported stat %r)" % (stat,))

    s = np.mean(vec)
    serr = my_bootstrap(vec, np.mean)

    value_node = texNode(str(np.round(s, 3)), x_loc, y_loc + .175, message)
    error_node = texNode("(" + str(np.round(serr, 3)) + ")", x_loc, y_loc - .175, message, "\\footnotesize ")
    return value_node + error_node


def my_sign_output(diff_stat, p_val, x_loc, y_loc, rotate=",rotate=0", size="\\scriptsize", showPvals="False"):
    """Return a TikZ node with a comparison symbol encoding direction and significance.

    The direction comes from the sign of `diff_stat` (negative gives a
    "less than" family symbol, otherwise "greater than"), and the strength
    from `p_val`: below .01 the triple symbol, below .05 the double symbol,
    below .1 the single symbol, otherwise a tilde (sim) in either direction.

    Args:
        diff_stat: Observed difference between the two groups.
        p_val: p-value of the comparison.
        x_loc, y_loc: Coordinates of the symbol node.
        rotate: Option string forwarded to texNode() as its `rotate` argument.
        size: LaTeX size command forwarded to texNode() for the symbol.
        showPvals: When "True" (or boolean True), a tiny node with the p-value
            rounded to 3 decimals is added 0.2 away from the symbol: along x
            when `rotate` is exactly ",rotate=-90", along y otherwise.

    Returns:
        The TikZ node command(s) as a string.
    """
    # LaTeX backslashes are escaped explicitly; unescaped forms such as a
    # backslash followed by "l" are invalid Python escape sequences.
    if p_val < .01:
        symbol = "$ \\lll $" if diff_stat < 0 else "$ \\ggg $"
    elif p_val < .05:
        symbol = "$ \\ll $" if diff_stat < 0 else "$ \\gg $"
    elif p_val < .1:
        symbol = "$ < $" if diff_stat < 0 else "$ > $"
    else:
        symbol = "$ \\sim $"

    temp_text = texNode(symbol, x_loc, y_loc, rotate, size)

    # The flag is historically the string "True"; a real boolean is accepted too.
    if showPvals is True or showPvals == "True":
        p_text = str(np.round(p_val, 3))
        if rotate == ",rotate=-90":
            temp_text = temp_text + texNode(p_text, x_loc + .2, y_loc, rotate, "\\tiny ")
        else:
            temp_text = temp_text + texNode(p_text, x_loc, y_loc + .2, rotate, "\\tiny ")

    return temp_text


def my_permutation_test(sample1, sample2, n_permut=10000):
    """Two-sided permutation test on the difference of means of two samples.

    The pooled observations are shuffled `n_permut` times with numpy's global
    random state; each shuffle is split at len(sample1) and the difference of
    the two group means is recorded. The p-value is the add-one-corrected
    share of shuffles whose absolute difference is at least as large as the
    observed one, rounded to 4 decimals.

    Returns:
        A two-element list: [observed mean difference, p-value].
    """
    n_first = len(sample1)
    observed = np.mean(sample1) - np.mean(sample2)

    pooled = np.append(np.array(sample1), np.array(sample2))

    # One row of indices per permutation, each shuffled in place.
    order = np.tile(np.arange(len(pooled)), (n_permut, 1))
    for row in order:
        np.random.shuffle(row)

    shuffled = pooled[order]
    first_means = shuffled[:, :n_first].mean(axis=1)
    second_means = shuffled[:, n_first:].mean(axis=1)
    null_diffs = first_means - second_means

    n_extreme = np.count_nonzero(np.abs(null_diffs) >= np.abs(observed))
    p_value = np.round((n_extreme + 1.0) / (n_permut + 1.0), 4)

    return [observed, p_value]


# MAIN FUNCTION  =======================================================================================================

def _load_game_data(eid):
    """Read one subject's game log from experiments/<eid>.csv and flag door rows."""
    game_data = pd.read_csv('experiments/' + str(eid) + '.csv', header=None)
    # Drop the first row by position. `DataFrame.ix` no longer exists in
    # current pandas; with the default integer index produced by header=None,
    # `.iloc[1:]` selects the same rows. The copy avoids chained-assignment
    # warnings when columns are added below.
    game_data = game_data.iloc[1:].copy()
    game_data.columns = ['eid', 'config', 'round_no', 'moves_taken', '1st_waitimes',
                         'last_action', 'balance', 'position0', 'position1',
                         'prize_pos0', 'prize_pos1', 'atPrize', 'atStart', 'atDoor']

    at_door = game_data.atDoor == 1
    game_data['atDoor1'] = at_door & (game_data.position0 == 4)
    game_data['atDoor2'] = at_door & ((game_data.position0 == 2) | (game_data.position0 == 0))
    return game_data


def _summarize_subjects(sub_info):
    """Build the per-subject statistics table from every subject's game log."""
    n_subs = sub_info.shape[0]

    output_labels = ["SubjectID", 'Earnings', 'Prizes', "Doors", "Door1", "Door2", "Starts",
                     "Incentives", "Structures", "Break1", "Door2Given1",
                     "Door2Back", "MedPeriodHit", "MedCycleHit"]
    output = {label: np.zeros(n_subs) for label in output_labels}

    for i, eid in enumerate(sub_info.eid):
        sub_data = sub_info[sub_info.eid == eid]
        game_data = _load_game_data(eid)

        at_door = game_data.atDoor == 1
        # Rows with moves_taken == 500; presumably the final move of a
        # session -- TODO confirm against the experiment design.
        at_move_500 = game_data.moves_taken == 500
        # Smallest Config value among this subject's rows in the summary file.
        config = np.min(sub_data.Config)

        # NOTE: the original code branched on how many rows have
        # moves_taken == 500, but both branches were identical, so a single
        # code path is kept.
        output['SubjectID'][i] = eid
        output['Earnings'][i] = np.min(game_data.balance[at_move_500])
        output['Prizes'][i] = np.sum(game_data.atPrize)
        output['Doors'][i] = np.sum(game_data.atDoor)
        output['Door1'][i] = np.sum(game_data.atDoor1)
        output['Door2'][i] = np.sum(game_data.atDoor2)
        output['Starts'][i] = np.sum(game_data.atStart)
        # Config is decoded as (Structures * 3 + Incentives + 1).
        output['Incentives'][i] = int((config - 1) % 3)
        output['Structures'][i] = int(np.floor((config - 1) / 3))
        # First move at which the subject is at position (5, 3).
        output['Break1'][i] = np.min(
            game_data.moves_taken[(game_data.position0 == 5) & (game_data.position1 == 3)])
        output['Door2Given1'][i] = np.sum(
            (at_door & (game_data.moves_taken > output['Break1'][i])) |
            (at_move_500 & (game_data.position0 == 2) & (game_data.position1 == 3)))
        output['Door2Back'][i] = np.sum(at_door & (game_data.position0 == 0))
        output['MedPeriodHit'][i] = np.median(game_data.moves_taken[at_door])
        output['MedCycleHit'][i] = np.median(game_data.round_no[at_door])

    output = pd.DataFrame(output)
    output['ExploratoryActions'] = output.Door2Given1 - output.Door2Back
    # Drop subjects without an Earnings value (no row with moves_taken == 500).
    return output[~np.isnan(output.Earnings)]


def _build_table(all_data):
    """Return the TikZ commands for the 2x2 table of mean ExploratoryActions."""
    # Table dimensions
    cell_width = 4
    cell_height = 1.25

    # Rows: values of the Incentives column to tabulate (labelled by
    # row_headers in order), and the value each one is compared against.
    dim1_var = 'Incentives'
    dim1_tab = [0, 2]
    dim1_comp = [2, 0]

    # Columns: values of the Structures column to tabulate (labelled by
    # col_headers in order), and the value each one is compared against.
    dim2_var = 'Structures'
    dim2_tab = [0, 3]
    dim2_comp = [3, 0]

    stat_variable = 'ExploratoryActions'

    col_headers = ["\\bf Baseline", "\\bf Breakthrough"]
    row_headers = ["\\bf Gains", "\\bf Losses"]

    pieces = []
    for d1i, dim1 in enumerate(dim1_tab):
        for d2i, dim2 in enumerate(dim2_tab):
            x_loc = d2i * cell_width
            y_loc = -(d1i + 1) * cell_height

            if d2i == 0:
                pieces.append(texNode(row_headers[d1i], -cell_width, y_loc, width="text width=4cm "))
            if d1i == 0:
                pieces.append(texNode(col_headers[d2i], x_loc, 0, width="text width=4cm,anchor=mid "))

            # Cell content: mean and bootstrap standard error.
            filtered_df0 = all_data[(all_data[dim1_var] == dim1) & (all_data[dim2_var] == dim2)]
            pieces.append(my_stat_output(filtered_df0[stat_variable], x_loc, y_loc))

            if d1i < 1:
                # First row only: compare with the other row in the same
                # column; the symbol is rotated and placed between the rows.
                filtered_df1 = all_data[(all_data[dim1_var] == dim1_comp[d1i]) & (all_data[dim2_var] == dim2)]
                my_test = my_permutation_test(np.array(filtered_df0[stat_variable]),
                                              np.array(filtered_df1[stat_variable]))
                pieces.append(my_sign_output(my_test[0], my_test[1], x_loc,
                                             -(d1i + 1.5) * cell_height,
                                             ",rotate=-90,color=black", showPvals="False"))
                print(my_test)

            if d2i < 1:
                # First column only: compare with the other column in the
                # same row; the symbol is placed between the columns.
                filtered_df1 = all_data[(all_data[dim1_var] == dim1) & (all_data[dim2_var] == dim2_comp[d2i])]
                my_test = my_permutation_test(np.array(filtered_df0[stat_variable]),
                                              np.array(filtered_df1[stat_variable]))
                pieces.append(my_sign_output(my_test[0], my_test[1], x_loc + cell_width / 2,
                                             y_loc, ",color=black", showPvals="False"))
                print(my_test)

    return "".join(pieces)


def main():
    """Compute per-subject statistics, render Table 4, and save the data.

    Steps:
      1. Read the subject summary from FNAME_IN and, for each `eid`, the game
         log experiments/<eid>.csv; derive per-subject counts and medians.
      2. Build a 2x2 TikZ table of mean ExploratoryActions (with bootstrap
         standard errors and permutation-test significance symbols), wrap it
         with beginDocument()/endDocument(), and compile it via clean_output()
         as "03_table4" into the "//results//" subfolder.
      3. Write the cleaned per-subject statistics to FNAME_OUT.
    """
    # Import Subjects' Summary Data
    sub_info = pd.read_csv(FNAME_IN)
    print(FNAME_IN, ".head()", sub_info.head())

    all_data = _summarize_subjects(sub_info)
    print("Cleaned data:")
    print(all_data.head())

    out_text = beginDocument() + _build_table(all_data) + endDocument()
    clean_output(out_text, "03_table4", "//results//")

    all_data.to_csv(FNAME_OUT, index=False)


# COMMANDLINE  ===============================================================================================================================================

if __name__ == "__main__":
    # Run the analysis only when this file is executed as a script.
    main()
